import torch
import numpy as np

from abc import abstractmethod
from tqdm import tqdm
from torch.optim import Optimizer

# Module-level constant. Nothing in the visible part of this file reads it;
# presumably a relative tolerance consumed by importing modules - TODO confirm.
reps_rel = 1e-6

class TorchModel:
    """Base class for a one-tensor optimization problem driven by an optimizer.

    The decision variable is the tensor ``self.x``. Subclasses override
    :meth:`loss` (objective evaluated at ``self.x``) and
    :meth:`compute_value` (objective evaluated at an arbitrary point).

    Attributes set by the constructor:
        init_point: the tensor passed by the caller, or ``None``.
        N: ``n_features`` as passed by the caller.
        x: the optimized tensor (``requires_grad=True``).
        optimizer: optimizer installed via :meth:`set_optimizer` (initially
            ``None``).
        scheduler: initialised to ``None``; not used by this class.
        best_point: detached copy of the best point recorded so far.
        best_value: loss at ``best_point``; ``None`` until :meth:`optimize`
            has run.
        last_step_losses: loss values of every oracle call made since the
            list was last cleared (``optimize`` clears it after each step).

    NOTE: ``loss`` and ``compute_value`` are decorated with
    ``@abstractmethod`` but the class does not use ``ABCMeta``, so the
    decorator does not prevent instantiation; the methods raise
    ``NotImplementedError`` when called without being overridden.
    """

    def __init__(self, n_features: int, init_point: torch.Tensor = None):
        """Create the model and its optimization variable.

        Args:
            n_features: length of the randomly initialised variable; only
                used to size ``self.x`` when ``init_point`` is ``None``.
            init_point: optional starting point. A detached copy of it
                becomes ``self.x``, so the caller's tensor is not modified.
        """
        # NOTE: this changes torch's print precision for the whole process.
        torch.set_printoptions(14)
        self.init_point = init_point
        self.N = n_features
        if self.init_point is None:
            self.x = torch.randn(self.N, requires_grad=True, dtype=torch.double)
        else:
            self.x = self.init_point.clone().detach()
            self.x.requires_grad = True
        self.optimizer = None
        self.scheduler = None
        self.best_point = self.x.clone().detach()
        self.best_value = None

        self.last_step_losses = []

    def closure(self):
        """Evaluate the loss as one recorded oracle call (no backward pass).

        This is the callable handed to ``optimizer.step(closure=...)``.
        """
        return self.loss_with_oracle_call()

    def _take_step(self):
        """Run one optimizer step; return ``(pre_step_point, loss_value)``.

        The returned point is a detached copy of ``self.x`` taken *before*
        ``optimizer.step`` runs, i.e. the point at which ``loss_value`` was
        evaluated.
        """
        self.optimizer.zero_grad()
        loss = self.loss_with_oracle_call()
        loss.backward()
        # Read the value and snapshot the point before step() can move self.x.
        loss_value = loss.item()
        point = self.x.clone().detach()
        self.optimizer.step(closure=self.closure)
        return point, loss_value

    def _optimizer_has_d_estimator(self):
        """Return True if the optimizer reports having a distance estimator.

        Optimizers that do not define ``has_d_estimator`` (e.g. stock
        ``torch.optim`` optimizers) are treated as having none.
        """
        probe = getattr(self.optimizer, "has_d_estimator", None)
        return callable(probe) and bool(probe())

    def optimize(self,
                 optimal_point,
                 max_iter: int = 10000,
                 log_per: int = 1000):
        """Step the optimizer until at least ``max_iter`` oracle calls are spent.

        The iteration counter advances by the number of oracle calls
        (entries appended to ``last_step_losses``) made in each step, and
        one entry per oracle call is appended to each returned list.

        Args:
            optimal_point: reference optimum; it is passed to
                ``compute_value`` and subtracted from NumPy arrays, so it is
                presumably an ``np.ndarray`` shaped like ``self.x``.
            max_iter: budget of oracle calls.
            log_per: refresh the status line roughly every ``log_per``
                oracle calls; a non-positive value disables the status line.

        Returns:
            ``(estimation_error_list, value_distance_list)``. The first is
            filled from ``optimizer.calculate_d_estimation_error`` only when
            the optimizer reports a distance estimator; the second holds
            ``best_value - compute_value(optimal_point)``.

        Raises:
            RuntimeError: if no optimizer was installed, or if a step
                recorded no oracle call (the loop could not make progress).
        """
        if self.optimizer is None:
            raise RuntimeError(
                "no optimizer configured; call set_optimizer() before optimize()")

        estimation_error_list = []
        value_distance_list = []
        optimal_value = self.compute_value(optimal_point)

        init_point = self.x.clone().detach().numpy()
        actual_d = np.linalg.norm(optimal_point - init_point)

        # Context managers guarantee both bars are closed, even on error.
        with tqdm(total=max_iter, position=0, desc="Iterations") as outer, \
                tqdm(total=0, position=2, bar_format='{desc}') as log:
            iteration = 0
            next_log_at = 0
            while iteration < max_iter:
                step_point, loss_value = self._take_step()
                calls = len(self.last_step_losses)
                if calls == 0:
                    raise RuntimeError(
                        "optimizer step recorded no oracle call; cannot advance")

                # First step: (re)start best tracking from the current point.
                if iteration == 0:
                    self.best_point = step_point
                    self.best_value = loss_value

                # NOTE(review): distances are recorded before the current
                # loss is folded into the best-so-far, so the series lags the
                # best value by one step. Kept as-is to preserve the output.
                point_distance = np.linalg.norm(self.best_point.numpy() - optimal_point)
                value_distance = self.best_value - optimal_value
                value_distance_list.extend([value_distance] * calls)

                if self._optimizer_has_d_estimator():
                    estimation_error_list.extend(
                        self.optimizer.calculate_d_estimation_error(actual_d)
                        for _ in range(calls))

                # Threshold-based so that steps spending several oracle calls
                # cannot jump over a logging point.
                if log_per > 0 and iteration >= next_log_at:
                    log.set_description_str(f"Iter: {iteration}, \t"
                                            f"Point Distance (d_i) from Optimal: {'%.6f' % point_distance}, \t"
                                            f"Value Distance from Optimal: \t{'%.6f' % value_distance}")
                    next_log_at = (iteration // log_per + 1) * log_per

                if loss_value < self.best_value:
                    # step_point is where loss_value was evaluated.
                    self.best_point = step_point
                    self.best_value = loss_value

                outer.update(calls)
                iteration += calls
                self.last_step_losses.clear()

        return estimation_error_list, value_distance_list

    def optimal_value(self):
        """Return the best loss value recorded so far (``None`` before any run)."""
        return self.best_value

    def set_optimizer(self, optimizer: Optimizer):
        """Install the optimizer that :meth:`optimize` will step."""
        self.optimizer = optimizer

    def loss_with_oracle_call(self):
        """Evaluate :meth:`loss` and record its value in ``last_step_losses``."""
        loss = self.loss()
        self.last_step_losses.append(loss.item())
        return loss

    @abstractmethod
    def loss(self):
        """Return the objective at ``self.x``; must be overridden.

        The result must support ``.item()`` and ``.backward()``, as used by
        this class.
        """
        raise NotImplementedError("must override loss")

    @abstractmethod
    def compute_value(self, point: np.ndarray):
        """Return the objective value at ``point``; must be overridden."""
        raise NotImplementedError("must override compute_value")

    def params(self):
        """Return the optimized tensors as a list (just ``self.x``)."""
        return [self.x]
